{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2维数据转换为3维\n",
    "\n",
    "import numpy as np\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([17520, 17])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bikes_numpy = np.loadtxt('../data/p1ch4/bike-sharing-dataset/hour-fixed.csv',\n",
    "                        dtype = np.float32,\n",
    "                        delimiter=',',\n",
    "                        skiprows =1,\n",
    "                        converters={1:lambda x: float(x[8:10])})  # 将日期转换成当月的第几天的形式，放到1列\n",
    "bikes = torch.from_numpy(bikes_numpy)\n",
    "bikes.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 3.0000e+00, 1.3000e+01,\n",
       "         1.6000e+01],\n",
       "        [2.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 8.0000e+00, 3.2000e+01,\n",
       "         4.0000e+01],\n",
       "        [3.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 5.0000e+00, 2.7000e+01,\n",
       "         3.2000e+01],\n",
       "        ...,\n",
       "        [1.7377e+04, 3.1000e+01, 1.0000e+00,  ..., 7.0000e+00, 8.3000e+01,\n",
       "         9.0000e+01],\n",
       "        [1.7378e+04, 3.1000e+01, 1.0000e+00,  ..., 1.3000e+01, 4.8000e+01,\n",
       "         6.1000e+01],\n",
       "        [1.7379e+04, 3.1000e+01, 1.0000e+00,  ..., 1.2000e+01, 3.7000e+01,\n",
       "         4.9000e+01]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bikes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "数据共有17列：\n",
    "0. Index of record: instant\n",
    "1. Day of month: day\n",
    "2. Season: season (1: spring, 2: summer, 3: fall, 4: winter)\n",
    "3. Year: yr (0: 2011, 1: 2012)\n",
    "4. Month: mnth (1 to 12)\n",
    "5. Hour: hr (0 to 23)\n",
    "6. Holiday status: holiday\n",
    "7. Day of the week: weekday\n",
    "8. Working day status: workingday\n",
    "9. Weather situation: weathersit (1: clear, 2:mist, 3: light rain/snow, 4: heavy rain/snow)\n",
    "10. Temperature in °C: temp\n",
    "11. Perceived temperature in °C: atemp\n",
    "12. Humidity: hum\n",
    "13. Wind speed: windspeed\n",
    "14. Number of casual users: casual\n",
    "15. Number of registered users: registered\n",
    "16. Count of rental bikes: cnt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([17520, 17]), (17, 1))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 转换为 N × C × L 的shape\n",
    "# C 是 17个特征（行），L是24个小时\n",
    "\n",
    "bikes.shape,bikes.stride()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([730, 24, 17]), (408, 17, 1))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "daily_bikes = bikes.view(-1,24,bikes.shape[1])   # 转换为天，小时，17(将17520分解)\n",
    "daily_bikes.shape,daily_bikes.stride()           # 返回的是新的张量，但是存储不变\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([730, 17, 24]), (408, 1, 17))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 再次转置\n",
    "daily_bikes = daily_bikes.transpose(1,2)\n",
    "daily_bikes.shape,daily_bikes.stride()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取天气情况，把它作为分类变量\n",
    "first_day = bikes[:24].long()\n",
    "weather_onehot = torch.zeros(first_day.shape[0],4)  # 行为24个小时，列为4种情况\n",
    "first_day[:,9]                                      # bikes中第9列为天气情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0.]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 把天气情况对应到weather_onehot\n",
    "weather_onehot.scatter_(dim=1,\n",
    "                       index = first_day[:,9].unsqueeze(1).long()-1,   #减1是因为天气情况为1-4，而指数为0-3,应该匹配\n",
    "                       value = 1.0)  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.0000,  1.0000,  1.0000,  0.0000,  1.0000,  0.0000,  0.0000,  6.0000,\n",
       "          0.0000,  1.0000,  0.2400,  0.2879,  0.8100,  0.0000,  3.0000, 13.0000,\n",
       "         16.0000,  1.0000,  0.0000,  0.0000,  0.0000]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用cat函数将矩阵连接到原始数据集上，先拿第一天举例\n",
    "\n",
    "torch.cat((bikes[:24],weather_onehot),1)[:1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([730, 4, 24])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# 在整体数据集上进行\n",
    "\n",
    "daily_weather_onehot = torch.zeros(daily_bikes.shape[0],4,daily_bikes.shape[2])\n",
    "# 把所有的天气数据转换为onehot\n",
    "daily_weather_onehot.scatter(dim = 1,\n",
    "                            index = daily_bikes[:,9,:].long().unsqueeze(1) -1 ,\n",
    "                            value = 1.0)\n",
    "\n",
    "daily_weather_onehot.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# onehot 与原数据连接\n",
    "daily_bikes = torch.cat((daily_bikes,daily_weather_onehot),dim = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 还可以将天气数据看作连续变量\n",
    "daily_bikes[:,9,:] = (daily_bikes[:,9,:]-1.0)/3.0   # 1-4 转换成0-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 调整温度信息，映射到0-1\n",
    "temp = daily_bikes[:,10,:]\n",
    "temp_min = torch.min(temp)\n",
    "temp_max = torch.max(temp)\n",
    "daily_bikes[:,10,:] = (daily_bikes[:,10,:]-temp_min)/(temp_max-temp_min)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 或者减去均值后除以标准差\n",
    "\n",
    "daily_bikes[:,10,:] = (daily_bikes[:,10,:]-torch.mean(temp))/torch.std(temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
